2025.7.30 Simulated Annealing法による2D-Rastrigin関数の最適化
参考:Rastrigin 関数
code:p.py
import numpy as np
import matplotlib.pyplot as plt
def rastrigin_function2d(x, y):
    """Return the 2-D Rastrigin function value.

    f(x, y) = A*n + sum over v in (x, y) of (v**2 - A*cos(2*pi*v)),
    with A = 10 and n = 2. Evaluated element-wise with NumPy, so x and y
    may be scalars or broadcast-compatible arrays. f(0, 0) == 0.
    """
    amplitude, dims = 10, 2
    # sum() starts from the constant term so the additions happen in the
    # order ((A*n + term_x) + term_y).
    return sum(
        (v**2 - amplitude * np.cos(2 * np.pi * v) for v in (x, y)),
        amplitude * dims,
    )
def create_x():
    """Draw one random candidate point uniformly from the search box.

    The first coordinate is sampled from the module-level ``x_range`` and the
    second from the module-level ``f_range`` (which, despite its name, is the
    y-axis range). Each is a (low, high) pair.

    Returns:
        np.ndarray of shape (2,): the sampled point (x, y).
    """
    # x is drawn before y, matching the order of the bounds tuple below.
    coords = [np.random.uniform(bounds[0], bounds[1]) for bounds in (x_range, f_range)]
    return np.array(coords)
def solve_sa(func, rho, iter_max, T0=1.0):
    """Minimise ``func(x, y)`` with a simulated-annealing loop.

    Each iteration multiplies the temperature by ``rho``, draws a candidate
    with ``create_x()`` and accepts it with the Metropolis probability
    ``min(1, exp(-(f_new - f_now) / T))``. The run stops after ``iter_max``
    iterations, or earlier as soon as ``T <= eps`` (module-level ``eps``).

    Side effects: appends to the module-level lists
      * ``x_hist`` / ``f_hist``: current state (initial state + one entry per
        iteration),
      * ``T_hist`` / ``chk_hist``: temperature and acceptance probability per
        iteration,
      * ``x_escape_hist`` / ``f_escape_hist``: the state left behind by each
        accepted non-improving move.

    Args:
        func: objective, called as ``func(x, y)``.
        rho: factor multiplied into T every iteration; values in (0, 1) cool.
        iter_max: maximum number of iterations (candidate evaluations).
        T0: initial temperature (default 1.0, the previously hard-coded value).

    Returns:
        (x_best, f_best): best point found (array of shape (2,)) and its value.
    """
    T = T0
    x_now = create_x()
    f_now = func(x_now[0], x_now[1])
    x_best, f_best = x_now, f_now
    x_hist.append(x_now)
    f_hist.append(f_now)
    # NOTE(review): create_x() samples the whole box independently of x_now,
    # so proposals are global rather than local moves; a neighbourhood
    # proposal (x_now plus a small perturbation) is the usual SA choice.
    for iter_count in range(iter_max):
        T *= rho
        x_new = create_x()
        f_new = func(x_new[0], x_new[1])
        delta = f_new - f_now
        # Metropolis acceptance probability min(1, exp(-delta / T)).
        # Clamping delta at 0 keeps exp() from overflowing on improving moves.
        # The former division by ||x_now|| is not part of the Metropolis rule
        # and divides by zero at the origin (the global optimum).
        metro_exp = np.exp(-max(delta, 0.0) / T)
        metro_rand = np.random.rand()
        if delta < 0 or metro_rand < metro_exp:
            if delta >= 0:
                # Accepted a non-improving move: remember where we escaped from.
                print('escape at {}'.format(iter_count))
                x_escape_hist.append(x_now)
                f_escape_hist.append(f_now)
            x_now, f_now = x_new, f_new
            if f_new < f_best:
                x_best, f_best = x_new, f_new
        # Record every iteration, including the final one.
        x_hist.append(x_now)
        f_hist.append(f_now)
        T_hist.append(T)
        chk_hist.append(metro_exp)
        if T <= eps:
            print('Tが小さくなったので計算を打ち切った', T, iter_count)
            break
    else:
        # Reached only when the loop ran all iter_max iterations without the
        # temperature break.
        print('計算回数が上限に達したので計算を打ち切った', T, iter_max)
    return x_best, f_best
############################
# Run the optimisation and plot the visited states on the function surface.
rho = 0.9999      # cooling factor: T <- rho * T every iteration
eps = 1e-7        # stop once T <= eps (same value as the former literal 10e-8)
iter_max = 10000  # maximum number of SA iterations
# NOTE: rho**iter_max = 0.9999**10000 ~= 0.37, so with these settings the
# iteration cap, not eps, is what ends the run.
x_range = -5.12, 5.12  # search range of the x-coordinate
f_range = -5.12, 5.12  # search range of the y-coordinate (name is read by create_x)
x_hist, f_hist = [], []
T_hist, chk_hist = [], []
x_escape_hist, f_escape_hist = [], []
func = rastrigin_function2d
x_best, f_best = solve_sa(func, rho, iter_max)
print('x best', x_best)
print('f best', f_best)
x_hist = np.stack(x_hist)   # shape (n_states, 2)
f_hist = np.stack(f_hist)   # shape (n_states,)
# plot
xx = np.linspace(x_range[0], x_range[1], 50)
yy = np.linspace(f_range[0], f_range[1], 50)
xxx, yyy = np.meshgrid(xx, yy)
zzz = func(xxx, yyy)
fig = plt.figure()
ax = fig.add_subplot(projection='3d')
# ax.plot_surface(xxx, yyy, zzz)  # alternative: shaded surface instead of wireframe
ax.plot_wireframe(xxx, yyy, zzz)
# The marker must be passed by keyword: the 4th positional parameter of
# Axes3D.scatter is zdir, so a positional '*' was silently ignored.
ax.scatter(x_hist[:, 0], x_hist[:, 1], f_hist, marker='*', color='red')
plt.show()